import torch
import imageio
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


def traversal_latents(base_latent, traversal_vector, dim):
    """Return copies of ``base_latent`` with a single coordinate swept.

    The base code is tiled ``len(traversal_vector)`` times along the first
    axis (``Tensor.repeat(n, 1)``). Column ``dim`` of the result is then
    overwritten with ``traversal_vector``, so row ``k`` holds
    ``traversal_vector[k]`` at position ``dim``. ``base_latent`` itself is
    left untouched because ``repeat`` copies its data.
    """
    sweep = base_latent.repeat(len(traversal_vector), 1)
    sweep[:, dim] = traversal_vector  # one sweep value per row
    return sweep


def rand_sample(ds, sz=1):
    """Return ``sz`` items of ``ds`` drawn uniformly without replacement.

    A random permutation of ``range(len(ds))`` comes from ``torch.randperm``.
    Its first ``sz`` entries are used as a single tensor index into ``ds``,
    so ``ds`` must support indexing by a 1-D ``LongTensor`` (e.g. a
    ``Tensor``). If ``sz`` exceeds ``len(ds)``, every item is returned.
    """
    chosen = torch.randperm(len(ds))[:sz]
    return ds[chosen]


def plot_bar(axes, images, label=None):
    """Draw one grayscale image per axis and hide all axis decorations.

    Args:
        axes: Sequence of matplotlib ``Axes``. Axes and images are paired
            with ``zip``, so any surplus on either side is ignored.
        images: Sequence of torch tensors, one per axis; each is moved to the
            CPU and converted to numpy before ``imshow``. Presumably 2-D
            (H, W) — callers pass channel-0 slices; TODO confirm.
        label: Optional text written vertically just right of the last axis.
    """
    for ax, img in zip(axes, images):
        # detach() so tensors that still track gradients can be converted.
        ax.imshow(img.detach().cpu().numpy(), cmap='gray')
        ax.axis('off')

    if label:
        # ``ax.axis('off')`` stops the axis artists (including any y-label)
        # from being drawn, so ``set_ylabel`` would never be visible. Place
        # the text in axes coordinates just right of the last panel instead.
        last = axes[-1]
        last.text(1.02, 0.5, label, transform=last.transAxes,
                  rotation=90, va='center', ha='left')


def plot_reconstruct(dataset, size, model):
    """Plot random dataset images next to the model's reconstructions.

    Args:
        dataset: Indexable by a 1-D LongTensor (it is sampled with
            ``rand_sample``); ``r * c`` items are drawn and fed to ``model``
            as one batch.
        size: ``(r, c)`` -- ``r`` rows and ``c`` origin/recon column pairs.
        model: Called as ``model(img)``; must return a 3-tuple whose first
            element is the reconstruction batch (the rest is unused here).

    Returns:
        The matplotlib figure with ``r`` rows and ``2 * c`` columns.
    """
    r, c = size
    with torch.no_grad():
        img = rand_sample(dataset, r * c)
        recon_batch, _, _ = model(img)

        # squeeze=False keeps ``axes`` 2-D even when r == 1.
        fig, axes = plt.subplots(r, c * 2, figsize=(c * 2 * 2, r * 2),
                                 squeeze=False)
        for i in range(c):
            # Pair i shows batch items [r*i, r*i + r), channel 0
            # (assumes (N, C, H, W) batches -- TODO confirm) in
            # columns 2*i (originals) and 2*i + 1 (reconstructions).
            rows = slice(r * i, r * i + r)
            origin_col, recon_col = 2 * i, 2 * i + 1
            plot_bar(axes[:, origin_col], img[rows, 0])
            axes[0, origin_col].set_title('origin')
            plot_bar(axes[:, recon_col], recon_batch[rows, 0])
            axes[0, recon_col].set_title('recon')

        # Lay out after the titles exist so tight_layout accounts for them.
        fig.tight_layout()
        return fig


def plot_rand_reconstruct(size, model):
    """Decode latents drawn from a standard normal and plot them in a grid.

    Args:
        size: ``(r, c)`` grid shape; ``r * c`` latent vectors of width
            ``model.latent_dim`` are sampled with ``torch.randn``.
        model: Must expose ``latent_dim`` and a ``decoder`` callable.

    Returns:
        The matplotlib figure with an ``r x c`` grid of decoded images.
    """
    r, c = size
    with torch.no_grad():
        latents_values = torch.randn(r * c, model.latent_dim)
        recon_batch = model.decoder(latents_values)

        # squeeze=False keeps ``axes`` an array even for a 1x1 grid.
        fig, axes = plt.subplots(r, c, figsize=(c * 4, r * 4), squeeze=False)
        # Draw channel 0 of each decoded item, as the other helpers here do.
        # Each image keeps its own (H, W) shape; the grid size (r, c) is
        # unrelated to the image size.
        # Assumes the decoder returns (N, C, H, W) -- TODO confirm.
        plot_bar(axes.flat, recon_batch[:, 0])
        fig.tight_layout()
        return fig


def plt_sample_traversal(sample, model, traversal_len=5, dim_list=range(4), r=3):
    """Plot decoder outputs while sweeping individual latent dimensions.

    For each ``dim`` in ``dim_list``, the base latent is copied
    ``traversal_len`` times with dimension ``dim`` replaced by evenly spaced
    values in ``[-r, r]``. That batch is decoded and drawn as one figure row.

    Args:
        sample: Input passed to ``model.encoder``; the first returned value
            is used as the base latent. It must encode to a single latent
            row, because ``traversal_latents`` cannot broadcast a larger
            batch. If ``None``, an all-zero latent of width
            ``model.latent_dim`` is used.
        model: Exposes ``encoder`` (returning a 2-tuple) and ``decoder``.
        traversal_len: Number of sweep values (figure columns).
        dim_list: Latent dimensions to sweep (figure rows).
        r: Sweep half-range; values go from ``-r`` to ``r``.

    Returns:
        The matplotlib figure.
    """
    dim_len = len(dim_list)
    with torch.no_grad():
        if sample is not None:
            mu, _ = model.encoder(sample)
        else:
            # The base latent must span the model's whole latent space, not
            # just len(dim_list) entries; otherwise dims >= len(dim_list)
            # index out of range. Fall back to dim_len for models without a
            # ``latent_dim`` attribute.
            latent_dim = getattr(model, 'latent_dim', dim_len)
            mu = torch.zeros(1, latent_dim)

        # squeeze=False already yields a (dim_len, traversal_len) array.
        fig, axes = plt.subplots(dim_len, traversal_len, squeeze=False,
                                 figsize=(traversal_len, dim_len))
        # ``pad`` is keyword-only in current Matplotlib.
        plt.tight_layout(pad=0.1)

        # The sweep values are the same for every dimension.
        linear_traversal = torch.linspace(-r, r, traversal_len)
        for i, dim in enumerate(dim_list):
            traversals = traversal_latents(mu, linear_traversal, dim)
            recon_batch = model.decoder(traversals)
            plot_bar(axes[i], recon_batch[:, 0])

        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        return fig


# Red-yellow-blue colormap; ``plot_trajectory`` uses it for its second scatter.
# ``plt.cm.get_cmap`` (i.e. ``matplotlib.cm.get_cmap``) was deprecated in
# Matplotlib 3.7 and removed in 3.9; ``pyplot.get_cmap`` works in old and new
# versions.
cm = plt.get_cmap('RdYlBu')


def plot_trajectory(sample, model, grid_size=32):
    """Scatter encoder means of two slices of ``sample`` in latent dims 0 and 1.

    ``sample`` is reshaped to ``(-1, 1, 64, 64)`` and encoded. Two point sets
    are drawn, each coloured by position within the set and given its own
    colorbar:

    * the first ``grid_size`` encodings (default colormap);
    * every ``grid_size``-th encoding (module-level ``cm`` colormap).

    If the inputs form a row-major ``grid_size x grid_size`` grid, these are
    presumably its first row and first column -- TODO confirm with callers.

    Args:
        sample: Tensor reshapeable to ``(-1, 1, 64, 64)``.
        model: Exposes ``encoder`` returning ``(mu, log_var)``.
        grid_size: Slice length / stride; defaults to the previously
            hard-coded 32.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure()
    with torch.no_grad():
        # reshape (not view) also accepts non-contiguous input.
        mu, _ = model.encoder(sample.reshape(-1, 1, 64, 64))
        mu = mu.cpu()
        first = mu[:grid_size]
        strided = mu[::grid_size]
        # Size ``c`` from each slice so scatter does not raise a length
        # mismatch when there are not exactly grid_size**2 samples.
        sc = plt.scatter(first[:, 0].numpy(), first[:, 1].numpy(),
                         c=range(len(first)))
        plt.colorbar(sc)
        sc = plt.scatter(strided[:, 0].numpy(), strided[:, 1].numpy(),
                         c=range(len(strided)), cmap=cm)
        plt.colorbar(sc)
    return fig


def plot_projection(samples, model, dim=(0, 1), labels=None):
    """Project a 2-D grid of images onto two latent dimensions and annotate it.

    Args:
        samples: 4-D tensor whose first two sizes ``(z1, z2)`` index the
            grid. It is flattened to ``(-1, 1, 64, 64)`` before encoding, so
            grid entry ``(j, i)`` becomes row ``i + z2 * j``.
        model: Exposes ``encoder`` returning ``(mu, log_var)``.
        dim: Latent indices; ``dim[0]`` is the x axis, ``dim[1]`` the y axis.
        labels: Optional, indexed as ``labels[j, i]``; its first two entries
            are printed (one decimal) instead of the default ``"(j,i)"`` tag.

    Returns:
        The matplotlib figure; each grid row is also joined by a line.
    """
    fig = plt.figure(figsize=(5, 5), dpi=300)
    n_rows, n_cols, _, _ = samples.shape
    x_dim, y_dim = dim[0], dim[1]

    with torch.no_grad():
        means, _ = model.encoder(samples.reshape(-1, 1, 64, 64))
        means = means.cpu()
        xs = means[:, x_dim]
        ys = means[:, y_dim]
        plt.scatter(xs.data, ys.data)

        # One polyline per grid row.
        for row in range(n_rows):
            lo, hi = row * n_cols, (row + 1) * n_cols
            plt.plot(xs[lo:hi], ys[lo:hi])

        # Annotate each point in row-major order.
        for flat_idx in range(n_rows * n_cols):
            j, i = divmod(flat_idx, n_cols)
            if labels is not None:
                pos = labels[j, i]
                plt.text(xs[flat_idx], ys[flat_idx], f'{pos[0]:.1f},{pos[1]:.1f}', size=6)
            else:
                plt.text(xs[flat_idx], ys[flat_idx], f'({j},{i})')

    return fig


def gt_vs_latent(variables, labels):
    """Scatter every latent variable against every ground-truth factor.

    Args:
        variables: 2-D tensor ``(n_samples, n_variables)``; column ``j`` is
            the x data of every plot in grid column ``j``.
        labels: 2-D tensor ``(n_samples, n_factors)``; column ``i`` is the
            y data of every plot in grid row ``i``.

    Returns:
        The matplotlib figure with an ``n_factors x n_variables`` grid.
    """
    # Convert once. ``Tensor.numpy()`` raises for CUDA tensors and for
    # tensors that require grad, so detach and move to the CPU first.
    xs = variables.detach().cpu().numpy()
    ys = labels.detach().cpu().numpy()
    n_factors, n_variables = ys.shape[1], xs.shape[1]

    fig, axes = plt.subplots(n_factors, n_variables, squeeze=False)
    for i in range(n_factors):  # factor
        for j in range(n_variables):  # variable
            axes[i, j].scatter(xs[:, j], ys[:, i], s=0.2)
    return fig
